#!/usr/bin/env python3
import numpy as np

# core: doubling schedule
def f_X_T(T, lam):
    """Performance ratio of the doubling schedule at interruption time(s) T.

    With k = floor(log2(T / lam)), returns T / (lam * 2**(k - 1)), i.e. T divided
    by the completed length reported by ``ell_X_T``. In exact arithmetic the
    value lies in [2, 4). ``T`` may be a scalar or array-like.
    """
    scaled = np.asarray(T, dtype=float) / lam
    exponent = np.floor(np.log2(scaled)).astype(int) - 1
    # Dividing by a power of two is exact, so (T/lam)/2^e equals T/(lam*2^e).
    return scaled / 2.0 ** exponent

def ell_X_T(T, lam):
    """Completed contract length l(T) of the doubling schedule.

    With k = floor(log2(T / lam)), returns lam * 2**(k - 1). ``T`` may be a
    scalar or array-like.
    """
    ratio = np.asarray(T, dtype=float) / lam
    exponent = np.floor(np.log2(ratio)).astype(int) - 1
    return lam * 2.0 ** exponent

def lam_for_jump(T_next, k):
    """Return lambda such that T_next = lambda * 2**k, i.e. T_next / 2**k.

    ``k`` may be a Python int, a NumPy integer, or an integer array, and may
    be negative.
    """
    # Float base: integer ``2 ** k`` raises ValueError for negative NumPy
    # integers and for integer arrays containing negative entries.
    return T_next / 2.0 ** k

# weights for MAX/AVG on [y(1-δ), y(1+δ)]
def w_linear(T, y, d):
    """Triangular weight centred on ``y`` with half-width ``d * y``.

    w(T) = max(0, 1 - |T - y| / (d*y)): the weight is 1 at T == y, falls
    linearly to 0 at y(1 +/- d), and is 0 outside that window.

    Parameters
    ----------
    T : float or array_like
        Evaluation point(s); scalars are accepted.
    y : float
        Window centre.
    d : float
        Relative half-width (delta).

    Returns
    -------
    Weights with the same shape as ``T``. For a zero half-width the window
    degenerates to the indicator of ``T == y``.
    """
    e = d * y
    T = np.asarray(T, float)
    if e == 0:
        # Zero-width window: avoid 0/0 (NaN) and return the limiting indicator.
        return np.where(T == y, 1.0, 0.0)
    # np.maximum already zeroes every point with |T - y| > e, so no separate
    # out-of-window mask (and no item assignment, which fails for scalars).
    return np.maximum(0.0, 1.0 - np.abs(T - y) / e)

def w_gauss(T, y, d, s_scale=4.0, trunc=True):
    """Unnormalised Gaussian window centred on ``y``.

    The standard deviation is s = d*y / s_scale and the peak value is 1.

    Parameters
    ----------
    T : float or array_like
        Evaluation point(s); scalars are accepted.
    y : float
        Window centre.
    d : float
        Relative half-width (delta).
    s_scale : float
        Ratio between the half-width d*y and the standard deviation.
    trunc : bool
        If True, weights outside [y(1-d), y(1+d)] are set to 0.

    Returns
    -------
    Weights with the same shape as ``T``. For a zero standard deviation the
    window degenerates to the indicator of ``T == y``.
    """
    e = d * y
    s = e / s_scale
    T = np.asarray(T, float)
    if s == 0:
        # Zero-width window: avoid 0/0 (NaN) at the centre.
        w = np.where(T == y, 1.0, 0.0)
    else:
        w = np.exp(-0.5 * ((T - y) / s) ** 2)
    if trunc:
        lo, hi = y * (1 - d), y * (1 + d)
        # np.where instead of item assignment so scalar T also works.
        w = np.where((T < lo) | (T > hi), 0.0, w)
    return w

def get_w(name, *, s_scale=4.0, trunc=True):
    """Return the weight function registered under ``name`` (case-insensitive).

    Parameters
    ----------
    name : str
        ``"linear"`` (w_linear) or ``"gaussian"`` (w_gauss).
    s_scale, trunc :
        Forwarded to w_gauss for the Gaussian window. They are ignored for
        ``"linear"``. The defaults reproduce the previous fixed behaviour.

    Returns
    -------
    callable
        ``f(T, y, d) -> weights``.

    Raises
    ------
    ValueError
        For an unknown name.
    """
    key = name.lower()
    if key == "linear":
        return w_linear
    if key == "gaussian":
        return lambda T, y, d: w_gauss(T, y, d, s_scale=s_scale, trunc=trunc)
    raise ValueError("weight_type ∈ {linear, gaussian}")

# μ for CVaR/expectations
def make_mu_grid(y, d, *, mu_type="gaussian", s_scale=4.0, k_sigma=5.0, n=1200):
    """Build a discretised density mu around ``y``, normalised to unit area.

    Parameters
    ----------
    y : float
        Centre of the distribution.
    d : float
        Relative half-width (delta).
    mu_type : {"gaussian", "linear"}
        ``"gaussian"``: normal pdf with sigma = d*y / s_scale on
        [max(1e-9, y - k_sigma*sigma), y + k_sigma*sigma].
        ``"linear"``: triangular w_linear on [max(1e-9, y(1-d)), y(1+d)].
    n : int
        Number of grid points.

    Returns
    -------
    (x, mu) : tuple of ndarray
        Grid and density values. mu integrates to 1 under the trapezoidal
        rule, or is all zeros if the raw integral is not positive.

    Raises
    ------
    ValueError
        For an unknown ``mu_type``.
    """
    # np.trapz is deprecated since NumPy 2.0 in favour of np.trapezoid.
    trapezoid = getattr(np, "trapezoid", None) or np.trapz
    mu_type = mu_type.lower()
    eps = 1e-9  # keep the grid strictly positive (log2 of T is taken later)
    if mu_type == "gaussian":
        s = (d * y) / s_scale
        lo, hi = max(eps, y - k_sigma * s), y + k_sigma * s
        x = np.linspace(lo, hi, n)
        mu = np.exp(-0.5 * ((x - y) / s) ** 2) / (s * np.sqrt(2 * np.pi))
    elif mu_type == "linear":
        lo, hi = max(eps, y * (1 - d)), y * (1 + d)
        x = np.linspace(lo, hi, n)
        mu = w_linear(x, y, d)
    else:
        raise ValueError("mu_type ∈ {gaussian, linear}")
    # Renormalise on the (possibly truncated) grid.
    Z = trapezoid(mu, x)
    mu = mu / Z if Z > 0 else np.zeros_like(mu)
    return x, mu

# λ search on [1,2)
def _d_max(lam, y, d, wfn):
    """Worst weighted deviation max_T w(T)*|f(T) - 2|.

    Evaluated on 1200 evenly spaced points of [y(1-d), y(1+d)].
    """
    grid = np.linspace(y * (1 - d), y * (1 + d), 1200)
    deviation = np.abs(f_X_T(grid, lam) - 2.0)
    weighted = deviation * wfn(grid, y, d)
    return float(weighted.max())

def _d_avg(lam, y, d, wfn):
    """Grid mean of w(T)*|f(T) - 2|.

    Evaluated on 1200 evenly spaced points of [y(1-d), y(1+d)]. The result is
    not divided by the total weight.
    """
    grid = np.linspace(y * (1 - d), y * (1 + d), 1200)
    deviation = np.abs(f_X_T(grid, lam) - 2.0)
    weighted = deviation * wfn(grid, y, d)
    return float(weighted.mean())

def _k_for_cvar(lam, y, d):
    a = max(1e-12, y - d * y)
    return int(np.floor(np.log2(a / lam)))

def _A_for_cvar(lam, k, y, d, x, mu):
    """A=∫ μ on [y-δy, λ·2^{k+1}] ∩ grid."""
    lo, hi = y * (1 - d), lam * (2 ** (k + 1))
    L, H = float(x[0]), float(x[-1])
    s, t = max(lo, L), min(hi, H)
    if s >= t: return 0.0
    m = (x >= s) & (x <= t)
    return float(np.trapz(mu[m], x[m])) if np.any(m) else 0.0

def _cvar_val(lam, y, d, a, x, mu):
    """CVaR-style objective to be maximised over lam.

    With k from _k_for_cvar, l = lam * 2^(k-1), and A the mu-mass from
    _A_for_cvar, the result is:

    - l * (2 - A) if a == 0;
    - otherwise max(l / (1 - a) * (2(1 - a) - A), l).

    a == 1 divides by zero.
    """
    k = _k_for_cvar(lam, y, d)
    mass = _A_for_cvar(lam, k, y, d, x, mu)
    ell = lam * (2 ** (k - 1))
    if a == 0:
        return ell * (2 - mass)
    tail = (ell / (1 - a)) * (2 * (1 - a) - mass)
    return max(tail, ell)

def find_lambda(y, d, objective, *, alpha=0.5, wfn=w_linear, x_mu=None, mu=None, n_grid=4000):
    """Grid-search lambda over ``n_grid`` evenly spaced points of [1, 2).

    ``objective``:

    - ``"max"`` minimises _d_max;
    - ``"avg"`` minimises _d_avg;
    - ``"cvar"`` maximises _cvar_val at level ``alpha`` and requires
      ``x_mu`` and ``mu``.

    Ties resolve to the first grid point (np.argmin / np.argmax).

    Raises
    ------
    ValueError
        For an unknown objective, or for ``"cvar"`` without ``x_mu``/``mu``.
    """
    candidates = np.linspace(1.0, 2.0, n_grid, endpoint=False)
    if objective in ("max", "avg"):
        score = _d_max if objective == "max" else _d_avg
        losses = [score(c, y, d, wfn) for c in candidates]
        return candidates[int(np.argmin(losses))]
    if objective != "cvar":
        raise ValueError("objective ∈ {max, avg, cvar}")
    if x_mu is None or mu is None:
        raise ValueError("x_mu, mu required for CVaR.")
    gains = [_cvar_val(c, y, d, alpha, x_mu, mu) for c in candidates]
    return candidates[int(np.argmax(gains))]

# metrics
def avg_ratio(lam, y, d, n=600):
    """Mean of f_X_T over ``n`` evenly spaced points of [y(1-d), y(1+d)]."""
    grid = np.linspace(y * (1 - d), y * (1 + d), n)
    ratios = f_X_T(grid, lam)
    return float(ratios.mean())

def E_ratio(lam, x, mu):
    """Expected ratio E_mu[f]: trapezoidal integral of f_X_T(x, lam) * mu over ``x``."""
    # np.trapz is deprecated since NumPy 2.0 in favour of np.trapezoid.
    trapezoid = getattr(np, "trapezoid", None) or np.trapz
    return float(trapezoid(f_X_T(x, lam) * mu, x))

def E_length(lam, x, mu):
    """Expected completed length E_mu[l]: trapezoidal integral of ell_X_T(x, lam) * mu over ``x``."""
    # np.trapz is deprecated since NumPy 2.0 in favour of np.trapezoid.
    trapezoid = getattr(np, "trapezoid", None) or np.trapz
    return float(trapezoid(ell_X_T(x, lam) * mu, x))

def win_rate(lam_alg, lam_base, y, d, n=600):
    """Fraction of ``n`` evenly spaced points of [y(1-d), y(1+d)] where
    f_X_T(T, lam_alg) is strictly below f_X_T(T, lam_base)."""
    grid = np.linspace(y * (1 - d), y * (1 + d), n)
    alg_ratio = f_X_T(grid, lam_alg)
    base_ratio = f_X_T(grid, lam_base)
    return float((alg_ratio < base_ratio).mean())

# experiment
def contract_experiment(
    *, n_trials=100, y_low=0.8e6, y_high=1.2e6, deltas=(0.2, 1/3, 0.4),
    weight_type="linear", mu_type="gaussian",
    seed=42, s_scale=4.0, k_sigma=5.0, alphas=(0.1, 0.5, 0.9),
    n_boot=800, n_grid_lambda=3000
):
    """Monte Carlo comparison of lambda-selection rules, with bootstrap CIs.

    For every ``d`` in ``deltas``, ``n_trials`` centres y ~ U[y_low, y_high)
    are drawn. For each y:

    - the mu grid is built with make_mu_grid;
    - lambda is chosen by MAX, AVG, and one CVaR rule per entry of ``alphas``
      (named ``f"CVaR_{alpha}"``; any number of alphas is accepted);
    - two baselines are derived from the MAX choice via
      k0 = floor(log2(y / lam_MAX)): PO solves y = lam * 2^k0 and deltaT
      solves (1-d)*y = lam * 2^k0.

    Per trial and rule the function records:

    - the window-average ratio;
    - E[f] and E[l] under mu;
    - the fraction of the window on which the rule's ratio is strictly below
      each baseline's. The baselines are re-anchored at the rule's own k.

    ``seed`` seeds the trial sampling. Bootstrap resampling uses a separate
    stream, so the sampled y values do not depend on ``n_boot``.

    Returns
    -------
    OUT_S, OUT_CI, OUT_ALL : dict
        All keyed by delta. OUT_S maps statistic name -> mean over trials.
        OUT_CI maps the same names -> (mean - q_lo, q_hi - mean) of a 95%
        percentile bootstrap. OUT_ALL holds the raw per-trial lists.
    """
    # np.trapz is deprecated since NumPy 2.0 in favour of np.trapezoid.
    trapezoid = getattr(np, "trapezoid", None) or np.trapz
    rng = np.random.default_rng(seed)
    # Independent bootstrap stream. Drawing bootstrap indices from ``rng``
    # would shift the y values sampled for later deltas whenever n_boot
    # changes. ``jumped()`` advances a copy of the state, leaving ``rng``
    # untouched.
    boot_rng = np.random.default_rng(rng.bit_generator.jumped())
    wfn = get_w(weight_type)
    algs = [("MAX", "max", None), ("AVG", "avg", None)]
    algs += [(f"CVaR_{a}", "cvar", a) for a in alphas]
    bases = ["PO", "deltaT"]
    OUT_S, OUT_CI, OUT_ALL = {}, {}, {}

    def ci_asym(x, conf=0.95):
        """Return (mean, (mean - q_lo, q_hi - mean)) from a percentile bootstrap."""
        a = np.asarray(x, float)
        if a.size == 0:
            return float("nan"), (0.0, 0.0)
        m = float(np.mean(a))
        if a.size == 1:
            return m, (0.0, 0.0)
        idx = boot_rng.integers(0, a.size, size=(n_boot, a.size))
        boots = np.mean(a[idx], axis=1)
        ql, qh = np.quantile(boots, [(1 - conf) / 2, 1 - (1 - conf) / 2])
        return m, (m - float(ql), float(qh) - m)

    for d in deltas:
        R = {"baselines_avg": {b: [] for b in bases},
             "baselines_Eratio": {b: [] for b in bases},
             "baselines_Elength": {b: [] for b in bases},
             "E_opt_ratio": [], "E_opt_length": []}
        for name, _, _ in algs:
            R[name] = {"avg_ALG": [], "E_ratio": [], "E_length": [],
                       "imp_vs_PO": [], "imp_vs_deltaT": []}

        for _ in range(n_trials):
            y = rng.uniform(y_low, y_high)
            x_mu, mu = make_mu_grid(y, d, mu_type=mu_type, s_scale=s_scale,
                                    k_sigma=k_sigma, n=1200)

            # Reference values: ratio 2 and length T/2, integrated under mu.
            R["E_opt_ratio"].append(2.0)
            R["E_opt_length"].append(float(trapezoid(0.5 * x_mu * mu, x_mu)))

            lam_MAX = find_lambda(y, d, "max", wfn=wfn, x_mu=x_mu, mu=mu, n_grid=n_grid_lambda)
            k0 = int(np.floor(np.log2(y / lam_MAX)))
            lam_PO = lam_for_jump(y, k0)
            lam_dT = lam_for_jump((1 - d) * y, k0)

            R["baselines_avg"]["PO"].append(avg_ratio(lam_PO, y, d))
            R["baselines_avg"]["deltaT"].append(avg_ratio(lam_dT, y, d))
            R["baselines_Eratio"]["PO"].append(E_ratio(lam_PO, x_mu, mu))
            R["baselines_Eratio"]["deltaT"].append(E_ratio(lam_dT, x_mu, mu))
            R["baselines_Elength"]["PO"].append(E_length(lam_PO, x_mu, mu))
            R["baselines_Elength"]["deltaT"].append(E_length(lam_dT, x_mu, mu))

            # MAX is already solved above; algs[0] is always MAX.
            lam_map = {"MAX": lam_MAX}
            for name, obj, a in algs[1:]:
                lam_map[name] = find_lambda(y, d, obj, alpha=(a if obj == "cvar" else 0.5),
                                            wfn=wfn, x_mu=x_mu, mu=mu, n_grid=n_grid_lambda)

            for name, _, _ in algs:
                lam = lam_map[name]
                R[name]["avg_ALG"].append(avg_ratio(lam, y, d))
                R[name]["E_ratio"].append(E_ratio(lam, x_mu, mu))
                R[name]["E_length"].append(E_length(lam, x_mu, mu))

                # Baselines re-anchored at this rule's own doubling index.
                kA = int(np.floor(np.log2(y / lam)))
                lam_PO_A = lam_for_jump(y, kA)
                lam_dT_A = lam_for_jump((1 - d) * y, kA)
                R[name]["imp_vs_PO"].append(win_rate(lam, lam_PO_A, y, d))
                R[name]["imp_vs_deltaT"].append(win_rate(lam, lam_dT_A, y, d))

        # (summary key, samples) in the same order as the original report.
        stats = []
        for b in bases:
            stats += [(f"avg_ratio_{b}", R["baselines_avg"][b]),
                      (f"E_ratio_{b}", R["baselines_Eratio"][b]),
                      (f"E_length_{b}", R["baselines_Elength"][b])]
        stats += [("E_opt_ratio", R["E_opt_ratio"]),
                  ("E_opt_length", R["E_opt_length"])]
        for name, _, _ in algs:
            r = R[name]
            stats += [(f"avg_ratio_{name}", r["avg_ALG"]),
                      (f"E_ratio_{name}", r["E_ratio"]),
                      (f"E_length_{name}", r["E_length"]),
                      (f"time_{name}_lt_PO", r["imp_vs_PO"]),
                      (f"time_{name}_lt_deltaT", r["imp_vs_deltaT"])]

        S, CI = {}, {}
        for key, samples in stats:
            S[key], CI[key] = ci_asym(samples)

        OUT_S[d], OUT_CI[d], OUT_ALL[d] = S, CI, R

    return OUT_S, OUT_CI, OUT_ALL

# pretty print
def print_contract_results(S, CI, *, note="", y_range=None):
    """Print a compact per-delta report of contract_experiment summaries.

    Parameters
    ----------
    S, CI : dict
        Summary and interval dicts keyed by delta. Each maps a statistic name
        to its mean (S) or to ``(lower, upper)`` distances from the mean to
        the interval ends (CI).
    note : str
        Free text inserted into each section header.
    y_range : (float, float) or None
        Sampling range of y shown in the header. ``None`` prints the label
        ``y~U[0.8e6,1.2e6]``.

    Algorithm rows are taken from the ``avg_ratio_<name>`` keys in S, so
    results produced with any set of CVaR levels can be printed.
    """
    baselines = ("PO", "deltaT")
    baseline_keys = {f"avg_ratio_{b}" for b in baselines}
    prefix = "avg_ratio_"
    if y_range is None:
        y_label = "y~U[0.8e6,1.2e6]"
    else:
        y_label = f"y~U[{y_range[0]:g},{y_range[1]:g}]"

    def fmt(d, k):
        m = S[d][k]
        lo, hi = CI[d][k]
        return f"{m:.4f}  +{hi:.4f}/-{lo:.4f}"

    for d in S:
        algs = [k[len(prefix):] for k in S[d]
                if k.startswith(prefix) and k not in baseline_keys]
        print("\n" + "=" * 68)
        print(f"Contract Scheduling {note}  (delta={d:.3f}, {y_label})")
        for b in baselines:
            print(f"avg_ratio_{b:7s}: {fmt(d, f'avg_ratio_{b}')}")
        print(f"E_opt_ratio     : {fmt(d, 'E_opt_ratio')}")
        print(f"E_opt_length    : {fmt(d, 'E_opt_length')}")
        for b in baselines:
            print(f"E_ratio_{b:7s}  : {fmt(d, f'E_ratio_{b}')}")
            print(f"E_length_{b:6s} : {fmt(d, f'E_length_{b}')}")
        for n in algs:
            print(f"{n:12s} avg_ratio : {fmt(d, f'avg_ratio_{n}')}")
            print(f"{n:12s} E_ratio  : {fmt(d, f'E_ratio_{n}')}")
            print(f"{n:12s} E_length : {fmt(d, f'E_length_{n}')}")
            print(f"{n:12s} time<PO  : {fmt(d, f'time_{n}_lt_PO')}")
            print(f"{n:12s} time<dT  : {fmt(d, f'time_{n}_lt_deltaT')}")

# example
if __name__ == "__main__":
    # Reference run: Gaussian weight window and Gaussian mu for three deltas.
    config = dict(
        n_trials=100,
        y_low=0.8e6,
        y_high=1.2e6,
        deltas=(0.2, 1/3, 0.4),
        weight_type="gaussian",
        mu_type="gaussian",
        seed=42,
        n_boot=800,
        n_grid_lambda=3000,
    )
    summaries, intervals, _ = contract_experiment(**config)
    print_contract_results(summaries, intervals, note="[w=gaussian, μ=gaussian]")
